Defining Custom Builders
Following is an example defining a new builder for creating a simple fully-connected neural network with two hidden layers, with n1
nodes in the first hidden layer, and n2
nodes in the second, for use in any of the first three models in Table 1. The definition includes one mutable struct and one method:
mutable struct MyBuilder <: MLJFlux.Builder
n1 :: Int
n2 :: Int
end
function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out)
init = Flux.glorot_uniform(rng)
return Chain(
Dense(n_in, nn.n1, init=init),
Dense(nn.n1, nn.n2, init=init),
Dense(nn.n2, n_out, init=init),
)
end
Note here that n_in
and n_out
depend on the size of the data (see Table 1).
For a concrete image classification example, see Using MLJ to classifiy the MNIST image dataset.
More generally, defining a new builder means defining a new struct sub-typing MLJFlux.Builder
and defining a new MLJFlux.build
method with one of these signatures:
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out)
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier`
This method must return a Flux.Chain
instance, chain
, subject to the following conditions:
chain(x)
must make sense:- for any
x <: Array{<:AbstractFloat, 2}
of size(n_in, batch_size)
wherebatch_size
is any integer (for all models exceptImageClassifier
); or - for any
x <: Array{<:Float32, 4}
of size(W, H, n_channels, batch_size)
, where(W, H) = n_in
,n_channels
is 1 or 3, andbatch_size
is any integer (for use withImageClassifier
)
- for any
The object returned by
chain(x)
must be anAbstractFloat
vector of lengthn_out
.
Alternatively, use MLJFlux.@builder(neural_net)
to automatically create a builder for any valid Flux chain expression neural_net
, where the symbols n_in
, n_out
, n_channels
and rng
can appear literally, with the interpretations explained above. For example,
builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh))